from openai import OpenAI
import logging
from app.base.logger import setup_logger
from app.utils.config_manager import config
from typing import List, Dict, Any, Optional

# 使用新的日志配置
logger = setup_logger("embedding_service")

class EmbeddingService:
    """
    文本嵌入服务，用于将文本转换为向量表示
    """
    def __init__(self):
        self.model_name = None
        self.api_key = None
        self.api_base = None
        self.initialized = False
        self.error_message = None

    def initialize(self, model_name=None, api_key=None, api_base=None):
        """
        初始化嵌入服务
        
        Args:
            model_name: 模型名称
            api_key: API密钥
            api_base: API基础URL
        
        Returns:
            bool: 初始化是否成功
        """
        try:
            logger.info("初始化嵌入服务")
            
            # 从配置中获取参数
            self.model_name = config.embedding_config.get("model_name")
            api_key = config.embedding_config.get("api_key")
            api_base = config.embedding_config.get("api_base")
            
            # 创建OpenAI客户端实例
            self.client = OpenAI(
                api_key=api_key,
                base_url=api_base  # 确保这里包含 "/v1" 或在配置中已包含
            )
                
            self.initialized = True
            return True
        except Exception as e:
            self.error_message = f"初始化失败: {str(e)}"
            return False
    
    def embed_text(self, text: str) -> List[float]:
        """
        将文本转换为向量表示
        
        Args:
            text: 输入文本
        
        Returns:
            List[float]: 向量表示
        """
        if not self.initialized:
            logger.error("嵌入服务未初始化，无法执行嵌入")
            return None
        
        try:
            if not text or text.strip() == "":
                logger.warning("嵌入文本为空")
                return None
            
            logger.debug(f"对文本进行嵌入: {text[:50]}...")
            
            # 调用OpenAI API获取嵌入
            response = self.client.embeddings.create(
                model=self.model_name,
                input=text
            )
            
            # 提取嵌入向量
            embedding = response.data[0].embedding
            
            logger.debug(f"嵌入完成，向量维度: {len(embedding)}")
            return embedding
            
        except Exception as e:
            logger.exception(f"文本嵌入失败: {str(e)}")
            self.error_message = f"嵌入异常: {str(e)}"
            return None
    
    def batch_embed_texts(self, texts: List[str]) -> List[List[float]]:
        """
        批量将多个文本转换为向量表示
        
        Args:
            texts: 输入文本列表
        
        Returns:
            List[List[float]]: 向量表示列表
        """
        if not self.initialized:
            logger.error("嵌入服务未初始化，无法执行嵌入")
            return None
        
        try:
            if not texts or len(texts) == 0:
                logger.warning("嵌入文本列表为空")
                return []
            
            # 筛选有效文本
            valid_texts = [text for text in texts if text and text.strip() != ""]
            if len(valid_texts) == 0:
                logger.warning("筛选后的嵌入文本列表为空")
                return []
            
            logger.debug(f"批量嵌入 {len(valid_texts)} 个文本")
            
            # 调用OpenAI API批量获取嵌入
            response = self.client.embeddings.create(
                model=self.model_name,
                input=valid_texts
            )
            
            # 提取所有嵌入向量
            embeddings = [item.embedding for item in response.data]
            
            logger.debug(f"批量嵌入完成，共 {len(embeddings)} 个向量")
            return embeddings
            
        except Exception as e:
            logger.exception(f"批量文本嵌入失败: {str(e)}")
            self.error_message = f"批量嵌入异常: {str(e)}"
            return None
    
    def get_error_message(self):
        """获取最后一次错误信息"""
        return getattr(self, 'error_message', "未知错误")

# 创建全局实例
embedding_service = EmbeddingService() 